import argparse
import datetime
import json
import os
import pymysql

# 通过指定脚本参数将innobackup批量备份的表恢复到指定的mysql服务中 #
"""版本改动
1.为了降低与py版本耦合度,执行shell的方法不再依赖subprocess.run(shell),而使用os.system(shell)
2.参数 --import-dataDir 改为 --import
"""


def run_cmd(shell):
    time_stamp = datetime.datetime.now().strftime('%Y.%m.%d-%H:%M:%S')
    print(time_stamp + "    [shell]# " + shell)
    result = os.system(shell)
    if result != 0:
        exit()


# 测试连接mysql库的方法,默认连接mysql库
def estimate_connectMysql(host, user, passwd, port, database="mysql"):
    try:
        conn = pymysql.connect(host=host, user=user, password=passwd, port=int(port), database=database)
        conn.close()
        # print("连接mysql成功")
        return True
    except Exception as err:
        # print("连接mysql失败")
        print(err)
        return False


def query_sql(mysql_ip, mysql_user, mysql_passwd, mysql_port, databases_name, sql_n):
    db = pymysql.connect(host=mysql_ip, user=mysql_user, port=int(mysql_port), password=mysql_passwd, db=databases_name)
    cursor = db.cursor()  # 使用cursor()方法获取操作游标
    cursor.execute(sql_n)  # 执行sql语句
    data = cursor.fetchall()  # 获取数据
    new_data_n = []
    for i in data:
        # print(i)  # 返回元组
        new_data_n.append(i[0])
    cursor.close()
    db.close()
    return new_data_n


def main(config_file, target_tables_name, import_data_dir):
    """前置运行环境检查"""
    # 1-判断json文件是否存在并读取
    # jsonConfigFile_path = "config.json"
    jsonConfigFile_path = config_file
    if os.path.isfile(jsonConfigFile_path) and os.path.isfile(target_tables_name) and os.path.isdir(import_data_dir):
        with open(jsonConfigFile_path, "rb") as user_file:
            file_json = json.load(user_file)
        # print(file_json)

        # 获取json中的参数值
        mysql_ip = file_json['Mysql_ip']
        mysql_port = file_json['Mysql_port']
        mysql_user = file_json['Mysql_user']
        mysql_passwd = file_json['Mysql_passwd']
        # +++字段处理-start+++
        mysql_datadir_1 = file_json['Mysql_datadir']        # mysql服务的数据目录
        if mysql_datadir_1[-1] != "/":      # 字段处理(判断结尾是否有"/"符号)
            mysql_datadir = mysql_datadir_1 + "/"
        else:
            mysql_datadir = mysql_datadir_1
        # +++字段处理-end+++
        userGroup = file_json['userGroup']      # mysql服务的用户和用户组

        # target_tablesName_filePath = file_json['target_tablesName_filePath']
        target_tablesName_filePath = target_tables_name     # 记录 db.table 的txt文件

        # +++字段处理-start+++
        # import_dataPath_1 = file_json['import_dataPath']
        import_dataPath_1 = import_data_dir
        if import_dataPath_1[-1] != "/":
            import_dataPath = import_dataPath_1 + "/"
        else:
            import_dataPath = import_dataPath_1
        # +++字段处理-end+++
    else:
        time_stamp = datetime.datetime.now().strftime('%Y.%m.%d-%H:%M:%S')
        print(time_stamp + "    参数指定的文件或文件夹不存在.")
        exit()
    # 2-测试mysql是否可以正常连接
    connect_result = estimate_connectMysql(mysql_ip, mysql_user, mysql_passwd, mysql_port)
    if not connect_result:
        time_stamp = datetime.datetime.now().strftime('%Y.%m.%d-%H:%M:%S')
        print(time_stamp + "    mysql服务连接测试,失败.")
        exit()
    # 3-判断配置文件内的参数值是否符合要求(路径、文件是否都存在)
    if not os.path.isdir(mysql_datadir) and os.path.isfile(target_tablesName_filePath) and os.path.isdir(import_dataPath):
        time_stamp = datetime.datetime.now().strftime('%Y.%m.%d-%H:%M:%S')
        print(time_stamp + "    " + jsonConfigFile_path + " 参数值配置异常:")
        print("  |__请检查配置中的路径或文件是否存在")
        exit()

    # 4-解析target_tablesName.txt并核对内容（数据库、表是否都存在）
    # 4.1-解析target_tablesName.txt
    dic_database_table = {}  # 定义一个空的字典
    f = open(target_tablesName_filePath, 'r')  # 读取txt
    lines = f.readlines()
    for line in lines:
        new_lines = line.strip()  # 去除空行、空格
        databases_name = new_lines.split(".")[0]  # 获取这一行中的库名称
        table_name = new_lines.split(".")[1]  # 获取这一行中的表名称

        if databases_name in dic_database_table:
            # print("${database_name}在${dic_database_table}中存在,直接添加进数组.")
            dic_database_table[databases_name].append(table_name)
        else:
            # print("${database_name}在${dic_database_table}中不存在,新建数组value")
            dic_database_table[databases_name] = []
            dic_database_table[databases_name].append(table_name)
    # print(dic_database_table)

    # 4.2.1-核对库名在mysql服务中是否存在
    databases_names_not = []    # 存储检查时不存在的库名称
    for database_name in dic_database_table.keys():
        connect_full_result = estimate_connectMysql(mysql_ip, mysql_user, mysql_passwd, mysql_port, database_name)
        if not connect_full_result:
            databases_names_not.append(database_name)
    if len(databases_names_not) != 0:
        while True:
            get_sql = input("有不存在的数据库,是否创建? :(y/n)")
            if get_sql == "y":
                for name_d in databases_names_not:
                    query_sql(mysql_ip, mysql_user, mysql_passwd, mysql_port, "mysql", "CREATE DATABASE %s CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci;" % name_d)
                print("创建完成,请重新运行程序.")
                exit()
            elif get_sql == "n":
                print("退出")
                exit()
            else:
                pass
    # 4.2.2-核对表是否存在
    table_check_result = True
    for database_name in dic_database_table.keys():
        # 获取库中所有的表名称,返回数组
        sql_result = query_sql(mysql_ip, mysql_user, mysql_passwd, mysql_port, database_name, "show tables;")
        for t_name in dic_database_table[database_name]:
            if t_name not in sql_result:
                print(database_name + "." + t_name + " ---- 未查询到该表,请确认是否已导入了空表.")
                table_check_result = False
    if not table_check_result:
        exit()

    # 5-对比 innobackup 备份的数据中是否存在txt中的库名
    for database_name in dic_database_table.keys():
        if not os.path.isdir(import_dataPath + database_name):
            print(database_name + " 库在 " + import_dataPath + " 备份中不存在.")
            exit()

    time_stamp = datetime.datetime.now().strftime('%Y.%m.%d-%H:%M:%S')
    print(time_stamp + "    前置运行环境检查结果: \033[32m正常\033[0m")
    print("确认以下操作后即可开始执行")
    print("\033[32m---------\033[0m")
    print("1.备份数据回滚检测:")
    print("  [shell]# innobackupex --apply-log --export " + import_dataPath)
    print("2.临时关闭外键约束检查[全局]:")
    print("  [mysql]> SET @@global.Foreign_key_checks=0;")
    print("  [mysql]> SELECT @@global.Foreign_key_checks;")
    print("\033[32m---------\033[0m")
    while True:
        get_in = input("是否开始迁移数据(y/n): ")
        if get_in == "y":
            pass
            break  # 终止这个循环
        elif get_in == "n":
            print("退出")
            exit()
        else:
            pass

    # 开始运行迁移数据的代码
    for database_name in dic_database_table.keys():
        db = pymysql.connect(host=mysql_ip, port=int(mysql_port), user=mysql_user, passwd=mysql_passwd, db=database_name)

        tableName = dic_database_table[database_name]
        for table_name in tableName:
            # 1-按照顺序组装要执行的shell和sql语句
            sql_cmd_discard = "alter table " + database_name + "." + table_name + " discard tablespace;"  # 清空表空间的sql
            cmd_cp_frm = "/usr/bin/cp -ar " + import_dataPath + database_name + "/" + table_name + ".frm " + mysql_datadir + database_name + "/"
            cmd_cp_ibd = "/usr/bin/cp -ar " + import_dataPath + database_name + "/" + table_name + ".ibd " + mysql_datadir + database_name + "/"
            cmd_chown = "chown -R " + userGroup + " " + mysql_datadir + database_name
            sql_cmd_import = "alter table " + database_name + "." + table_name + " import tablespace;"

            time_stamp = datetime.datetime.now().strftime('%Y.%m.%d-%H:%M:%S')
            print(time_stamp + "    开始导入 [" + database_name + "." + table_name + "]")

            # 2-执行 sql_cmd_discard 的sql语句(清空表空间)
            try:
                cursor_discard = db.cursor()  # 通过cursor()方法创建游标
                cursor_discard.execute(sql_cmd_discard)  # 执行sql语句
                db.commit()  # 提交到数据库执行
                cursor_discard.close()
                # 输出执行的sql语句
                time_stamp = datetime.datetime.now().strftime('%Y.%m.%d-%H:%M:%S')
                print(time_stamp + "    [mysql]> " + sql_cmd_discard)
            except Exception as err:
                db.close()
                print("清空表空间,出现异常.")
                print(err)
                exit()      # 异常之后退出

            # 3-调用方法执行shell语句(cp复制,chown授权操作)
            run_cmd(cmd_cp_frm)
            run_cmd(cmd_cp_ibd)
            run_cmd(cmd_chown)

            # 4-执行 sql_cmd_import 的sql语句(导入表空间)
            try:
                cursor_import = db.cursor()  # 通过cursor()方法创建游标
                cursor_import.execute(sql_cmd_import)  # 执行sql语句
                db.commit()  # 提交到数据库执行
                cursor_import.close()
                # 输出执行的sql语句
                time_stamp = datetime.datetime.now().strftime('%Y.%m.%d-%H:%M:%S')
                print(time_stamp + "    [mysql]> " + sql_cmd_import)
            except Exception as err:
                db.close()
                print("导入表空间,出现异常.")
                print(err)
                exit()      # 异常之后退出

            time_stamp = datetime.datetime.now().strftime('%Y.%m.%d-%H:%M:%S')
            print(time_stamp + "    导入完成 [" + database_name + "." + table_name + "]")
            print("------")
    print("")
    time_stamp = datetime.datetime.now().strftime('%Y.%m.%d-%H:%M:%S')
    print(time_stamp + "    全部数据导入\033[32m 成功\033[0m.")
    print("请手动恢复(开启)外键约束检查[全局]:")
    print("  [mysql]> SET @@global.Foreign_key_checks=1;")
    print("  [mysql]> SELECT @@global.Foreign_key_checks;")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument('--config', type=str, dest="config", required=True, help="记录mysql相关参数的json文件")
    parser.add_argument('--txt', type=str, dest="txt", required=True, help="记录 db.table 的txt文件")
    parser.add_argument('--import', type=str, dest="importPath", required=True, help="innobackup备份好的数据目录")
    args = parser.parse_args()

    config = args.config
    tablesName_txt = args.txt
    importPath = args.importPath
    # print(config)
    # print(tablesName_txt)
    # print(importPath)

    main(config, tablesName_txt, importPath)
